package clear.experiment;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;

import clear.dep.DepNode;
import clear.dep.DepTree;
import clear.dep.srl.SRLHead;
import clear.dep.srl.SRLInfo;
import clear.reader.SRLReader;
import clear.util.cluster.Prob2dMap;
import clear.util.tuple.JObjectDoubleTuple;

/**
 * Collects, per value of a predicate's {@code "ct"} feature, co-occurrence statistics between
 * SRL argument labels and the lemmas of noun arguments. It then extracts sets of lemmas
 * ("topics") that are strongly associated with a given argument label.
 *
 * <p>Predicates without a {@code "ct"} feature are ignored. NOTE(review): the meaning of
 * {@code "ct"} is not visible here. It is presumably a cluster/topic id assigned upstream; confirm.
 */
public class SRLTopicCluster
{
	/** Per {@code "ct"} value: incremented with (argument label, noun lemma) pairs. */
	HashMap<String, Prob2dMap> m_ta;
	/** Per {@code "ct"} value: incremented with (noun lemma, argument label) pairs. */
	HashMap<String, Prob2dMap> m_at;
	/** Per {@code "ct"} value: lemmas of predicates seen with a noun argument (not read in this class). */
	HashMap<String, HashSet<String>> s_verbs;
	
	/** Creates an empty collector. */
	public SRLTopicCluster()
	{
		m_ta = new HashMap<String, Prob2dMap>();
		m_at = new HashMap<String, Prob2dMap>();
		
		s_verbs = new HashMap<String, HashSet<String>>();
	}
	
	/**
	 * Accumulates statistics from one dependency tree.
	 *
	 * <p>This applies to every node that passes {@code isPosx("NN.*")} (presumably a regex match on
	 * the POS tag, i.e. nouns) and to every SRL head of that node whose predicate has a
	 * {@code "ct"} feature. Under that feature value, it increments the (label, lemma) entry of
	 * {@link #m_ta} and the (lemma, label) entry of {@link #m_at}, and records the predicate's
	 * lemma in {@link #s_verbs}.
	 *
	 * @param tree tree to read; index 0 is skipped (presumably an artificial root; confirm)
	 */
	public void retrieveTopics(DepTree tree)
	{
		for (int i=1; i<tree.size(); i++)
		{
			DepNode node = tree.get(i);
			if (!node.isPosx("NN.*"))	continue;
			
			SRLInfo info = node.srlInfo;
			
			for (SRLHead head : info.heads)
			{
				DepNode pred = tree.get(head.headId);
				String  feat = pred.getFeat("ct");
				if (feat == null)	continue;
				
				getSubMap(m_ta, feat).increment(head.label, node.lemma);
				getSubMap(m_at, feat).increment(node.lemma, head.label);
				getSubSet(s_verbs, feat).add(pred.lemma);
			}
		}
	}
	
	/**
	 * Extracts topic sets for one argument label and appends them to {@code aTopics}.
	 *
	 * <p>Each {@code "ct"} value is processed in turn:
	 * <ol>
	 * <li>Every lemma listed by {@code m_ta.get(ct).getProb1dList(argKey)} is scored as its listed
	 * value times {@code m_at.get(ct).get1dProb(lemma, argKey)}. NOTE(review): presumably
	 * P(lemma | label) * P(label | lemma); confirm against {@code Prob2dMap}.</li>
	 * <li>Lemmas scoring at least {@code threshold} form a candidate set.</li>
	 * <li>A candidate with at least {@code num} lemmas is appended, unless some set already in
	 * {@code aTopics} leaves it with fewer than {@code num} lemmas not in that set.</li>
	 * </ol>
	 *
	 * <p>The overlap filter is greedy, so which candidates survive depends on {@link HashMap}
	 * iteration order over the {@code "ct"} values.
	 *
	 * @param aTopics   list to append accepted topic sets to; its existing entries also act as filters
	 * @param argKey    SRL argument label to extract topics for (e.g. {@code "A1"})
	 * @param threshold minimum combined score for a lemma to join a candidate set
	 * @param num       minimum candidate size, and minimum number of lemmas a candidate must have
	 *                  outside each existing set in {@code aTopics}
	 */
	public void getTopics(ArrayList<HashSet<String>> aTopics, String argKey, double threshold, int num)
	{
		outer: for (String id : m_ta.keySet())
		{
			ArrayList<JObjectDoubleTuple<String>> aTA = m_ta.get(id).getProb1dList(argKey);
			if (aTA == null)	continue;
			
			// Always present: retrieveTopics populates m_ta and m_at under the same keys.
			Prob2dMap pAT = m_at.get(id);
			HashSet<String> topics = new HashSet<String>();
			
			for (JObjectDoubleTuple<String> tup : aTA)
			{
				// Score into a local instead of mutating tup.value in place. If getProb1dList
				// returns tuples backed by the map, mutation would corrupt later calls.
				double score = tup.value * pAT.get1dProb(tup.object, argKey);
				if (score >= threshold)	topics.add(tup.object);
			}
			
			if (topics.size() < num)	continue;
			
			// Reject the candidate if it adds too few new lemmas relative to any accepted topic.
			for (HashSet<String> pSet : aTopics)
			{
				HashSet<String> novel = new HashSet<String>(topics);
				novel.removeAll(pSet);
				if (novel.size() < num)	continue outer;
			}
			
			aTopics.add(topics);
		}
	}
	
	/** Returns the {@link Prob2dMap} stored under {@code key}, creating and storing an empty one if absent. */
	private Prob2dMap getSubMap(HashMap<String,Prob2dMap> map, String key)
	{
		Prob2dMap submap = map.get(key);
		
		if (submap == null)
		{
			submap = new Prob2dMap();
			map.put(key, submap);
		}
		
		return submap;
	}
	
	/** Returns the set stored under {@code key}, creating and storing an empty one if absent. */
	private HashSet<String> getSubSet(HashMap<String, HashSet<String>> map, String key)
	{
		HashSet<String> subset = map.get(key);
		
		if (subset == null)
		{
			subset = new HashSet<String>();
			map.put(key, subset);
		}
		
		return subset;
	}

	/**
	 * Runs the pipeline from the command line. It reads SRL trees from {@code args[0]}, extracts
	 * {@code "A1"} topics (score threshold 0.005, at least 10 lemmas), and writes the resulting
	 * {@code ArrayList<HashSet<String>>} to {@code args[1]} using Java object serialization.
	 *
	 * @param args {@code args[0]} input file, {@code args[1]} output file
	 */
	static public void main(String[] args)
	{
		if (args.length < 2)
		{
			System.err.println("Usage: java clear.experiment.SRLTopicCluster <input file> <output file>");
			System.exit(1);
		}
		
		String inputFile  = args[0];
		String outputFile = args[1];
		
		ArrayList<HashSet<String>> aTopics = new ArrayList<HashSet<String>>();
		SRLTopicCluster tbuild = new SRLTopicCluster();
		SRLReader reader = new SRLReader(inputFile, true);
		DepTree tree;
		
		while ((tree = reader.nextTree()) != null)
		{
			tbuild.retrieveTopics(tree);
		}
		
		// Enabling this line extracts A0 topics first; the A1 topics are then filtered against them.
	//	tbuild.getTopics(aTopics, "A0", 0.005, 10);
		tbuild.getTopics(aTopics, "A1", 0.005, 10);
		
		FileOutputStream fileOut = null;
		try
		{
			fileOut = new FileOutputStream(outputFile);
			ObjectOutputStream outputStream = new ObjectOutputStream(fileOut);
			outputStream.writeObject(aTopics);
			outputStream.close();	// flushes buffered object data and closes fileOut
		}
		catch (IOException e) {e.printStackTrace();}
		finally
		{
			// Release the file handle even if the header or object write failed.
			if (fileOut != null)
			{
				try {fileOut.close();}
				catch (IOException e) {e.printStackTrace();}
			}
		}
	}
}
